from __future__ import annotations

import torch
import torch.nn as nn

@torch.no_grad()
def step(update: dict[str, torch.Tensor], global_model: nn.Module) -> None:
    """Add ``update`` to the state of ``global_model``, entry by entry.

    Every entry of ``global_model.state_dict()`` is incremented by the
    tensor stored under the same key in ``update``. Each update tensor is
    first converted to the device and dtype of the state entry it is added
    to, so callers may pass updates held on another device or in another
    precision. Keys of ``update`` that are not in the state dict are ignored.

    Args:
        update: Mapping from state-dict key to the tensor to add. Must
            contain every key of ``global_model.state_dict()``.
        global_model: Module whose parameters and buffers are updated.

    Raises:
        KeyError: If ``update`` lacks a key present in the state dict. The
            check happens before anything is modified, so the model is left
            untouched in that case.

    Note:
        The update is cast to the dtype of the state entry, so for integer
        entries a floating-point update loses its fractional part.
    """
    model_state: dict[str, torch.Tensor] = global_model.state_dict()

    # Validate up front: the additions below are in place on tensors that
    # reference the live model state, so failing midway on a missing key
    # would leave the model only partially updated.
    missing = [name for name in model_state if name not in update]
    if missing:
        raise KeyError(f"update is missing entries for state keys: {missing}")

    for name, tensor in model_state.items():
        # Align device as well as dtype; `.to` is a no-op when both match.
        tensor.add_(update[name].to(device=tensor.device, dtype=tensor.dtype))

    # Write the result back explicitly; this also covers modules whose
    # state_dict() hands out copies rather than references.
    global_model.load_state_dict(model_state)
